import cdtools
from matplotlib import pyplot as plt
from scipy.io import loadmat, savemat
import numpy as np
import torch as t
from helper_functions import center_probe, center_probe_fourier
from cdtools.tools import plotting as p

# Root directory of the DiProI_202203 experiment data on this machine.
# Edit this to point at your own copy of the data.
base_dir = '/home/abe/Dropbox (MIT)/Photon Scattering Group/Data/Experiments/FEL/Fermi/DiProI_202203/'

# True  -> process the single-shot ptychography dataset;
# False -> process the single-shot RPI dataset.
ptycho = True

# Both preprocessed CXI files live in the 'B_Siemens' sub-folder.
filename = (
    'B_Siemens/Calibration_LowerRPI_Spiral_012_OF_PreProc.cxi'  # ptychography
    if ptycho
    else 'B_Siemens/SS_Singlepoint_RPI_x1shot_005__PreProc.cxi'  # RPI
)

# base_dir ends with '/', so plain concatenation yields a valid path.
# NOTE(review): cut_zeros=False presumably stops cdtools from trimming
# frames on its own, so all filtering happens explicitly below — confirm.
cxi_path = base_dir + filename
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(cxi_path, cut_zeros=False)

# Align the per-frame arrays. The final diffraction pattern is dropped, and
# the translations and intensities (which can hold more entries than there
# are patterns) are trimmed to the same number of frames.
# NOTE(review): the last pattern is discarded unconditionally, even when the
# arrays already agree in length — confirm this is intended.
n_keep = max(dataset.patterns.shape[0] - 1, 0)
dataset.patterns = dataset.patterns[:n_keep, :, :]
dataset.translations = dataset.translations[:n_keep, :]
dataset.intensities = dataset.intensities[:n_keep]

# Remove dead frames: shots with no recorded FEL intensity (intensity <= 0),
# plus, for the ptychography dataset, a few frames rejected by hand.
n_frames = dataset.patterns.shape[0]
if (dataset.translations.shape[0] != n_frames
        or dataset.intensities.shape[0] != n_frames):
    # A single mask is applied to all three arrays, so they must be aligned;
    # otherwise positions and intensities would be paired with the wrong
    # patterns.
    raise ValueError(
        'patterns, translations and intensities must have the same number '
        f'of frames, got {n_frames}, {dataset.translations.shape[0]} and '
        f'{dataset.intensities.shape[0]}')

# Kept as "<= 0" (not "not > 0") so frames with NaN intensity are retained.
dead_frames = set(t.argwhere(dataset.intensities <= 0).ravel().tolist())

if ptycho:
    dead_frames.update([50, 186])  # hand-selected bad frames

# Ignore hand-picked indices that fall past the end of a shorter dataset.
dead_frames = sorted(i for i in dead_frames if 0 <= i < n_frames)
print(dead_frames)

# Boolean keep-mask over the frame axis; indexing with it preserves frame
# order and drops every dead frame in one step.
keep = t.ones(n_frames, dtype=t.bool, device=dataset.patterns.device)
keep[t.tensor(dead_frames, dtype=t.long, device=keep.device)] = False

dataset.patterns = dataset.patterns[keep]
dataset.translations = dataset.translations[keep]
dataset.intensities = dataset.intensities[keep]

# Visual sanity check of the cleaned data, shown on a logarithmic scale.
dataset.inspect(logarithmic=True)

# Save the cleaned dataset under a name that records which dataset it is.
# The path is relative to the current working directory.
# NOTE(review): presumably to_cxi does not create 'data/' itself — confirm
# the folder exists before running.
output_path = ('data/single_shot_ptycho.cxi' if ptycho
               else 'data/single_shot_rpi.cxi')
dataset.to_cxi(output_path)

plt.show()
